{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.数据准备"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.0 加载库、定义函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "from PIL import Image, ImageDraw\n",
    "import time\n",
    "import os\n",
    "import math\n",
    "import random \n",
    "# 导入pytorch相关的库\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_filePath_list(dirPath, partOfFileName=''):\n",
    "    \"\"\"Return the paths of the files located directly inside ``dirPath``.\n",
    "\n",
    "    Args:\n",
    "        dirPath: Directory to list; sub-directories are not descended into.\n",
    "        partOfFileName: Only file names containing this substring are kept\n",
    "            (the default '' keeps every file).\n",
    "\n",
    "    Returns:\n",
    "        A list of ``os.path.join(dirPath, fileName)`` strings sorted by file\n",
    "        name; empty when ``dirPath`` does not exist or cannot be listed.\n",
    "    \"\"\"\n",
    "    # os.walk yields nothing for a missing directory; the default triple\n",
    "    # keeps next() from leaking a bare StopIteration to the caller.\n",
    "    _, _, all_fileName_list = next(os.walk(dirPath), (dirPath, [], []))\n",
    "    # os.walk gives no ordering guarantee, so sort to make the sample order\n",
    "    # (and every index-based split built on it) reproducible across runs.\n",
    "    fileName_list = sorted(k for k in all_fileName_list if partOfFileName in k)\n",
    "    return [os.path.join(dirPath, k) for k in fileName_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "def print_flush(print_string):\n",
    "    \"\"\"Print ``print_string`` as an in-place progress line.\n",
    "\n",
    "    The text is terminated with a carriage return instead of a newline, so\n",
    "    the next call overwrites it; stdout is flushed so it shows up at once.\n",
    "    \"\"\"\n",
    "    print(print_string, end='\\r', flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.1 定义数据加载器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1172, 293)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Directory that holds the .jpg images\n",
    "dirPath = '../resources/modified_jpgs/'\n",
    "imageFilePath_list = get_filePath_list(dirPath, '.jpg')\n",
    "N = len(imageFilePath_list)\n",
    "index_1dArray = np.arange(N)\n",
    "# Fixed seed: the train/test membership must be the same on every run,\n",
    "# otherwise reported test metrics are not comparable between runs.\n",
    "trainIndex_1dArray, testIndex_1dArray = train_test_split(\n",
    "    index_1dArray, test_size=0.2, random_state=0)\n",
    "len(trainIndex_1dArray), len(testIndex_1dArray)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names indexed by class id; id 0 is the background class\n",
    "category_list = ['background', 'keyPoint_1', 'keyPoint_2']\n",
    "id2name_dict = dict(enumerate(category_list))\n",
    "name2id_dict = {name: class_id for class_id, name in enumerate(category_list)}\n",
    "\n",
    "from xml.etree import ElementTree as ET\n",
    "def get_label(xmlFilePath):\n",
    "    \"\"\"Read the annotated boxes and class ids from an annotation xml file.\n",
    "\n",
    "    Every ``object`` element of the file contributes one box (taken from its\n",
    "    ``bndbox`` child) and one class id (its ``name`` text looked up in\n",
    "    ``name2id_dict``).\n",
    "\n",
    "    Args:\n",
    "        xmlFilePath: Path of the xml file.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(box_list, classId_list)``. ``box_list`` holds one\n",
    "        ``(xmin, ymin, xmax, ymax)`` tuple of ints per object and\n",
    "        ``classId_list`` the matching class ids. Both lists are empty when\n",
    "        the file does not exist.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If an object name is missing from ``name2id_dict``.\n",
    "    \"\"\"\n",
    "    box_list = []\n",
    "    classId_list = []\n",
    "    # A sample without annotation file simply has no objects. Return the same\n",
    "    # 2-tuple shape as the normal path so callers can always unpack it.\n",
    "    if not os.path.exists(xmlFilePath):\n",
    "        return box_list, classId_list\n",
    "    with open(xmlFilePath, encoding='utf8') as file:\n",
    "        root = ET.XML(file.read())\n",
    "    for object_item in root.findall('object'):\n",
    "        classId_list.append(name2id_dict[object_item.find('name').text])\n",
    "        bndbox = object_item.find('bndbox')\n",
    "        box = tuple(\n",
    "            int(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))\n",
    "        box_list.append(box)\n",
    "    return box_list, classId_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "读取全部图片到内存中, 总共耗时23.2949秒\n"
     ]
    }
   ],
   "source": [
    "def get_one_sample(imageFilePath):\n",
    "    \"\"\"Load one image together with its annotation.\n",
    "\n",
    "    Args:\n",
    "        imageFilePath: Path of the image file. The annotation is read from\n",
    "            the file with the same name and the extension '.xml'.\n",
    "\n",
    "    Returns:\n",
    "        ``(image_3dArray, label)``: the decoded image as a numpy array and\n",
    "        the value returned by ``get_label`` for the xml file.\n",
    "    \"\"\"\n",
    "    # The context manager releases the file as soon as the pixels are copied\n",
    "    with Image.open(imageFilePath) as image:\n",
    "        image_3dArray = np.array(image)\n",
    "    # splitext copes with extensions of any length (.jpg, .jpeg, ...),\n",
    "    # unlike cutting off a fixed three characters.\n",
    "    xmlFilePath = os.path.splitext(imageFilePath)[0] + '.xml'\n",
    "    return image_3dArray, get_label(xmlFilePath)\n",
    "\n",
    "startTime = time.time()\n",
    "# One preallocated uint8 buffer for all N images of 1920 x 320 x 3.\n",
    "# For the 1465 images used here that is about 2.51 GiB:\n",
    "# 1465 * 1920 * 320 * 3 / 2**30 = 2.5148\n",
    "all_images_4dArray = np.zeros((N, 1920, 320, 3), dtype='uint8')\n",
    "all_label_list = []\n",
    "for i, imageFilePath in enumerate(imageFilePath_list):\n",
    "    image_3dArray, label = get_one_sample(imageFilePath)\n",
    "    all_label_list.append(label)\n",
    "    all_images_4dArray[i] = image_3dArray\n",
    "usedTime = time.time() - startTime\n",
    "print('读取全部图片到内存中, 总共耗时%.4f秒' % (usedTime,))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batch_sample(batchIndex_1dArray):\n",
    "    \"\"\"Assemble one batch from the images and labels preloaded in memory.\n",
    "\n",
    "    Args:\n",
    "        batchIndex_1dArray: Integer indices into ``all_images_4dArray`` and\n",
    "            ``all_label_list``.\n",
    "\n",
    "    Returns:\n",
    "        ``(x, batch_label_list)``: ``x`` is a float tensor on the 'cuda'\n",
    "        device with the channel axis moved to position 1, and\n",
    "        ``batch_label_list`` holds the labels in the same order.\n",
    "    \"\"\"\n",
    "    # Indexing with an index array copies the selected images, so the tensor\n",
    "    # created from the copy does not alias the global image buffer.\n",
    "    batch_images_4dArray = all_images_4dArray[batchIndex_1dArray]\n",
    "    batch_label_list = [all_label_list[k] for k in batchIndex_1dArray]\n",
    "    # Transfer the uint8 data first and convert to float on the device:\n",
    "    # one byte per value crosses the bus instead of four.\n",
    "    x = torch.from_numpy(batch_images_4dArray).to('cuda')\n",
    "    # [N, H, W, C] -> [N, C, H, W]\n",
    "    x = x.permute(0, 3, 1, 2).float()\n",
    "    return x, batch_label_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import queue\n",
    "\n",
    "\n",
    "train_N = len(trainIndex_1dArray)\n",
    "\n",
    "\n",
    "class DataLoader(threading.Thread):\n",
    "    \"\"\"Background thread that keeps a few training batches ready in a queue.\n",
    "\n",
    "    The thread starts itself on construction. Whenever fewer than 3 batches\n",
    "    are queued it slices the next ``batch_size`` indices out of the global\n",
    "    ``trainIndex_1dArray``, builds the batch with ``get_batch_sample`` and\n",
    "    puts it into ``self.queue``.\n",
    "\n",
    "    Args:\n",
    "        batch_size: Number of samples per batch.\n",
    "        shuffle: When True, ``trainIndex_1dArray`` is shuffled in place at\n",
    "            the start of every epoch.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batch_size=32, shuffle=True):\n",
    "        # Daemon thread: the endless producer loop must never keep the Python\n",
    "        # process alive once the main thread has finished.\n",
    "        super(DataLoader, self).__init__(daemon=True)\n",
    "        # Output queue of (x, batch_label_list) tuples\n",
    "        self.queue = queue.Queue()\n",
    "        self.batch_size = batch_size\n",
    "        self.shuffle = shuffle\n",
    "        self.is_stopped = False\n",
    "        self.batch_index = 0\n",
    "        self.epoch_size = math.ceil(train_N / batch_size)\n",
    "        self.start()\n",
    "\n",
    "    def run(self):\n",
    "        \"\"\"Producer loop executed in the background thread.\"\"\"\n",
    "        while not self.is_stopped:\n",
    "            if self.queue.qsize() < 3:\n",
    "                # Start of an epoch. The rewind must not depend on `shuffle`:\n",
    "                # otherwise an unshuffled loader slices past the end of the\n",
    "                # index array and queues empty batches after the first epoch.\n",
    "                if self.batch_index % self.epoch_size == 0:\n",
    "                    self.batch_index = 0\n",
    "                    if self.shuffle:\n",
    "                        random.shuffle(trainIndex_1dArray)\n",
    "                start_index = self.batch_index * self.batch_size\n",
    "                end_index = start_index + self.batch_size\n",
    "                batchIndex_1dArray = trainIndex_1dArray[start_index: end_index]\n",
    "                self.queue.put(get_batch_sample(batchIndex_1dArray))\n",
    "                self.batch_index += 1\n",
    "            time.sleep(0.001)\n",
    "\n",
    "    def get_batch(self):\n",
    "        \"\"\"Block until a batch is available and return it.\"\"\"\n",
    "        return self.queue.get()\n",
    "\n",
    "    def stop(self):\n",
    "        \"\"\"Ask the producer loop to exit and drop the batches still queued.\"\"\"\n",
    "        self.is_stopped = True\n",
    "        while not self.queue.empty():\n",
    "            self.queue.get()\n",
    "\n",
    "    def __del__(self):\n",
    "        self.stop()\n",
    "\n",
    "\n",
    "batch_size = 32\n",
    "# Re-running this cell must not leave the previous producer thread alive\n",
    "# (it would keep holding the batches sitting in its queue).\n",
    "if isinstance(globals().get('train_loader'), threading.Thread):\n",
    "    train_loader.is_stopped = True\n",
    "train_loader = DataLoader(batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of batches the loader thread currently has waiting in its queue\n",
    "train_loader.queue.qsize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2.搭建神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicConv(nn.Module):\n",
    "    \"\"\"Conv2d optionally preceded by BatchNorm2d and ReLU, in that order.\n",
    "\n",
    "    The padding is ``(kernel_size - 1) // 2``, so with stride 1 and an odd\n",
    "    kernel the spatial size is preserved.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, bn=True):\n",
    "        super(BasicConv, self).__init__()\n",
    "        # A stage that is switched off is stored as None\n",
    "        self.bn = nn.BatchNorm2d(in_channels) if bn else None\n",
    "        self.relu = nn.ReLU() if relu else None\n",
    "        self.conv = nn.Conv2d(\n",
    "            in_channels, out_channels, kernel_size,\n",
    "            stride=stride, padding=(kernel_size - 1) // 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the enabled stages: bn -> relu -> conv.\"\"\"\n",
    "        for layer in (self.bn, self.relu):\n",
    "            if layer is not None:\n",
    "                x = layer(x)\n",
    "        return self.conv(x)\n",
    "\n",
    "    \n",
    "class Backbone(nn.Module):\n",
    "    \"\"\"Feature extractor: three stride-2 BasicConv layers with a 2x2 max-pool\n",
    "    after each of the first two, i.e. 32x spatial down-sampling in total.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Backbone, self).__init__()\n",
    "        # The first layer sees the raw image, so it skips the ReLU\n",
    "        self.conv1_1 = BasicConv(3, 8, kernel_size=3, stride=2, relu=False)\n",
    "        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv2_1 = BasicConv(8, 16, kernel_size=3, stride=2)\n",
    "        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv3_1 = BasicConv(16, 32, kernel_size=3, stride=2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map x of shape [N, 3, 1920, 320] to features of shape [N, 32, 60, 10].\n",
    "\n",
    "        Shapes after each stage: [N, 8, 960, 160], [N, 8, 480, 80],\n",
    "        [N, 16, 240, 40], [N, 16, 120, 20], [N, 32, 60, 10].\n",
    "        \"\"\"\n",
    "        stages = (self.conv1_1, self.pool1, self.conv2_1, self.pool2, self.conv3_1)\n",
    "        for stage in stages:\n",
    "            x = stage(x)\n",
    "        return x\n",
    "    \n",
    "\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Detection network: Backbone followed by a 3x3 prediction convolution.\n",
    "\n",
    "    Each feature-map cell gets ``2 + class_quantity`` output values; the loss\n",
    "    and detection code of this notebook read the first two as centre offsets\n",
    "    and the remaining ones as class scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, class_quantity=3):\n",
    "        super(Net, self).__init__()\n",
    "        self.class_quantity = class_quantity\n",
    "        self.backbone = Backbone()\n",
    "        self.prediction_conv = nn.Conv2d(32, 2 + class_quantity, kernel_size=3, padding=1)\n",
    "        # Not applied in forward(); the raw scores are returned\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map x of shape [N, 3, 1920, 320] to predictions of shape [N, 600, 5].\"\"\"\n",
    "        feature_map = self.backbone(x)                        # [N, 32, 60, 10]\n",
    "        prediction = self.prediction_conv(feature_map)        # [N, 5, 60, 10]\n",
    "        prediction = prediction.permute(0, 2, 3, 1)           # [N, 60, 10, 5]\n",
    "        batch, channels = prediction.size(0), prediction.size(-1)\n",
    "        return prediction.reshape(batch, -1, channels)        # [N, 600, 5]\n",
    "\n",
    "\n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 16.,  16.],\n",
       "        [ 48.,  16.],\n",
       "        [ 80.,  16.],\n",
       "        [112.,  16.],\n",
       "        [144.,  16.]])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_priorCenter_2d(row_quantity=60, column_quantity=10, stride=32):\n",
    "    \"\"\"Return the centre (x, y), in image pixels, of every feature-map cell.\n",
    "\n",
    "    Cells are enumerated row by row, so the cell in row ``i`` and column\n",
    "    ``j`` is found at index ``i * column_quantity + j``.\n",
    "\n",
    "    Args:\n",
    "        row_quantity: Number of cell rows (60 for the 1920 px image height).\n",
    "        column_quantity: Number of cell columns (10 for the 320 px width).\n",
    "        stride: Edge length of one cell in pixels.\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape [row_quantity * column_quantity, 2].\n",
    "    \"\"\"\n",
    "    priorCenter_list = [\n",
    "        ((j + 0.5) * stride, (i + 0.5) * stride)\n",
    "        for i in range(row_quantity)\n",
    "        for j in range(column_quantity)\n",
    "    ]\n",
    "    # reshape keeps the [*, 2] layout even when the grid is empty\n",
    "    return torch.tensor(priorCenter_list, dtype=torch.float32).reshape(-1, 2)\n",
    "\n",
    "\n",
    "priorCenter_2d = get_priorCenter_2d()\n",
    "priorCenter_2d[:5]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiBoxLoss(nn.Module):\n",
    "    \"\"\"Localisation + classification loss of the grid-cell key-point detector.\n",
    "\n",
    "    Args:\n",
    "        prior_2d: Optional [N_prior, 2] tensor with the (x, y) centre of every\n",
    "            grid cell. When omitted, the notebook-level ``priorCenter_2d`` is\n",
    "            looked up each time targets are built.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, prior_2d=None):\n",
    "        super(MultiBoxLoss, self).__init__()\n",
    "        self.prior_2d = prior_2d\n",
    "        self.locationLoss_function = nn.MSELoss(reduction='mean')\n",
    "        self.confidenceLoss_function = nn.CrossEntropyLoss(reduction='mean')\n",
    "\n",
    "    def match(self, label_list, device='cuda'):\n",
    "        \"\"\"Build the per-cell regression and classification targets.\n",
    "\n",
    "        Args:\n",
    "            label_list: One ``(box_list, classId_list)`` tuple per image.\n",
    "            device: Device the returned tensors are moved to.\n",
    "\n",
    "        Returns:\n",
    "            ``(gtOffset_3d, gtClassId_2d)`` of shapes [N, N_prior, 2] and\n",
    "            [N, N_prior]. The cell containing a box centre receives the box\n",
    "            class id and the offset from the cell centre to the box centre;\n",
    "            every other cell keeps class id 0 and a zero offset.\n",
    "        \"\"\"\n",
    "        prior_2d = priorCenter_2d if self.prior_2d is None else self.prior_2d\n",
    "        prior_2d = prior_2d.cpu()\n",
    "        N = len(label_list)\n",
    "        N_prior = len(prior_2d)\n",
    "        # zeros instead of torch.Tensor(N, N_prior, 2): the latter hands back\n",
    "        # uninitialised memory for all cells that are never written.\n",
    "        gtOffset_3d = torch.zeros(N, N_prior, 2)\n",
    "        gtClassId_2d = torch.zeros(N, N_prior).long()\n",
    "        for index, label in enumerate(label_list):\n",
    "            box_list, classId_list = label\n",
    "            for box, classId in zip(box_list, classId_list):\n",
    "                xmin, ymin, xmax, ymax = box\n",
    "                center_x = (xmin + xmax) // 2\n",
    "                center_y = (ymin + ymax) // 2\n",
    "                # Cell containing the centre: cells are 32 px wide and high,\n",
    "                # with 10 cells per row.\n",
    "                i = center_y // 32\n",
    "                j = center_x // 32\n",
    "                k = i * 10 + j\n",
    "                gtBox_1d = torch.Tensor([center_x, center_y])\n",
    "                gtOffset_3d[index, k] = gtBox_1d - prior_2d[k]\n",
    "                gtClassId_2d[index, k] = classId\n",
    "        return gtOffset_3d.to(device), gtClassId_2d.to(device)\n",
    "\n",
    "    def forward(self, predictions_3d, label_list):\n",
    "        \"\"\"Compute the two loss terms for one batch.\n",
    "\n",
    "        Args:\n",
    "            predictions_3d: Network output of shape [N, N_prior, 2 + classes];\n",
    "                the first two values of the last axis are the offsets.\n",
    "            label_list: One ``(box_list, classId_list)`` tuple per image.\n",
    "\n",
    "        Returns:\n",
    "            ``(location_loss, confidence_loss)`` as scalar tensors.\n",
    "        \"\"\"\n",
    "        pOffset_3d = predictions_3d[:, :, :2]\n",
    "        pConfidence_3d = predictions_3d[:, :, 2:]\n",
    "        gtOffset_3d, gtClassId_2d = self.match(\n",
    "            label_list, device=predictions_3d.device)\n",
    "        # The localisation error is computed on positive cells only\n",
    "        positive_2d = gtClassId_2d > 0\n",
    "        if positive_2d.any():\n",
    "            location_loss = self.locationLoss_function(\n",
    "                pOffset_3d[positive_2d], gtOffset_3d[positive_2d])\n",
    "        else:\n",
    "            # The mean over zero elements is NaN; a batch without any object\n",
    "            # contributes no localisation error (still attached to the graph).\n",
    "            location_loss = pOffset_3d.sum() * 0\n",
    "\n",
    "        # The confidence error uses the positives plus, per image, the 2\n",
    "        # negatives whose background confidence is lowest.\n",
    "        N_negtive = 2\n",
    "        with torch.no_grad():\n",
    "            negtiveConfidence_2d = F.softmax(pConfidence_3d, dim=2)[..., 0]\n",
    "            # Push positives to the end of the ranking so that they cannot\n",
    "            # occupy the slots reserved for negatives.\n",
    "            negtiveConfidence_2d = negtiveConfidence_2d.masked_fill(\n",
    "                positive_2d, float('inf'))\n",
    "            index_2d = negtiveConfidence_2d.sort(1)[1]\n",
    "            # Sorting the sort indices yields the rank of every cell\n",
    "            rank_2d = index_2d.sort(1)[1]\n",
    "            negtive_2d = rank_2d < N_negtive\n",
    "        isSelected_2d = positive_2d | negtive_2d\n",
    "        confidence_loss = self.confidenceLoss_function(\n",
    "            pConfidence_3d[isSelected_2d], gtClassId_2d[isSelected_2d])\n",
    "\n",
    "        return location_loss, confidence_loss\n",
    "    \n",
    "# MultiBoxLoss reads the notebook-level priorCenter_2d itself; the former\n",
    "# argument `priorBox_2d` is not defined anywhere in this notebook.\n",
    "criterion = MultiBoxLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.在训练集上训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 59 step: 37/37 loss:0.161226 location_loss:0.160827 confidence_loss:0.000020\n",
      "训练过程总共耗时233.8997秒\n"
     ]
    }
   ],
   "source": [
    "epochs = 60\n",
    "device = torch.device('cuda')\n",
    "net.to(device)\n",
    "# Make sure BatchNorm uses batch statistics and updates its running\n",
    "# averages, even if net.eval() (section 6.2) ran earlier in this kernel.\n",
    "net.train()\n",
    "train_N = len(trainIndex_1dArray)\n",
    "epoch_size = math.ceil(train_N / batch_size)\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)\n",
    "# The learning rate is halved (gamma=0.5) at every milestone epoch\n",
    "milestone_list = [10 * k for k in range(1, epochs//10)]\n",
    "scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestone_list, gamma=0.5)\n",
    "\n",
    "startTime = time.time()\n",
    "for epoch in range(epochs):\n",
    "    # One scheduler step for every finished epoch\n",
    "    if epoch > 0:\n",
    "        scheduler.step()\n",
    "    for step in range(1, epoch_size+1):\n",
    "        x, label_list = train_loader.get_batch()\n",
    "        predictions_3d = net(x)\n",
    "        location_loss, confidence_loss = criterion(predictions_3d, label_list)\n",
    "        # The confidence term is weighted 20x relative to the location term\n",
    "        loss = location_loss + confidence_loss * 20\n",
    "        loss_value = loss.item()\n",
    "        locationLoss_value = location_loss.item()\n",
    "        confidenceLoss_value = confidence_loss.item()\n",
    "        print_string = 'epoch: %d step: %d/%d loss:%.6f location_loss:%.6f confidence_loss:%.6f' %(\n",
    "            epoch, step, epoch_size, loss_value, locationLoss_value, confidenceLoss_value)\n",
    "        print_flush(print_string)\n",
    "        # Back-propagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "usedTime = time.time() - startTime\n",
    "print('\\n训练过程总共耗时%.4f秒' %usedTime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The loader thread keeps refilling its queue, also after training ended\n",
    "train_loader.queue.qsize()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5.在测试集上测试模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.1 对比单张图片的标注数据、预测数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def detect(image, verbose=True):\n",
    "    \"\"\"Run the network on one image and return the detected key points.\n",
    "\n",
    "    For each of the classes 1 and 2 the cell with the highest class\n",
    "    confidence is taken; if that confidence exceeds 0.5, a box reaching\n",
    "    15 px to each side of the predicted centre (clipped to the 320 x 1920\n",
    "    image) is reported.\n",
    "\n",
    "    Args:\n",
    "        image: Image of shape [H, W, 3] in any form ``np.array`` accepts.\n",
    "        verbose: When True, the elapsed time is printed.\n",
    "\n",
    "    Returns:\n",
    "        ``(pBox_list, pClassId_list)``: boxes as ``(min_x, min_y, max_x,\n",
    "        max_y)`` int tuples and the class id of each box.\n",
    "    \"\"\"\n",
    "    startTime = time.time()\n",
    "    images_4dArray = np.expand_dims(np.array(image), 0)\n",
    "    # A single image must be normalised with the running BatchNorm\n",
    "    # statistics, so run in eval mode and restore the previous mode after.\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            x = torch.ByteTensor(images_4dArray).to(device)\n",
    "            x = x.permute(0,3,1,2).float()\n",
    "            predictions_3d = net(x)\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "    prediction_2d = predictions_3d[0]\n",
    "    pOffset_2d = prediction_2d[:, :2]\n",
    "    pConfidence_2d = F.softmax(prediction_2d[:, 2:], 1)\n",
    "    pBox_list = []\n",
    "    pClassId_list = []\n",
    "    for classId in [1, 2]:\n",
    "        classConfidence_1d = pConfidence_2d[:, classId]\n",
    "        maxConfidence = classConfidence_1d.max().item()\n",
    "        if maxConfidence > 0.5:\n",
    "            pClassId_list.append(classId)\n",
    "            # Plain int, so it can index the CPU tensor of cell centres\n",
    "            index = classConfidence_1d.argmax().item()\n",
    "            pCenter = priorCenter_2d[index] + pOffset_2d[index].to('cpu')\n",
    "            center_x, center_y = int(pCenter[0]), int(pCenter[1])\n",
    "            min_x = max(center_x - 15, 0)\n",
    "            max_x = min(center_x + 15, 320-1)\n",
    "            min_y = max(center_y - 15, 0)\n",
    "            max_y = min(center_y + 15, 1920-1)\n",
    "            pBox_list.append((min_x, min_y, max_x, max_y))\n",
    "\n",
    "    usedTime = time.time() - startTime\n",
    "    if verbose:\n",
    "        print('预测耗时%.4f秒' %usedTime)\n",
    "    return pBox_list, pClassId_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "标注数据: ([(99, 508, 129, 538)], [2])\n",
      "预测耗时0.0060秒\n",
      "预测数据: ([(98, 507, 128, 537)], [2])\n"
     ]
    }
   ],
   "source": [
    "def test_1():\n",
    "    \"\"\"Print the annotation and the prediction of one random test image.\"\"\"\n",
    "    imageFilePath = imageFilePath_list[random.choice(testIndex_1dArray)]\n",
    "    image_3dArray, gtLabel = get_one_sample(imageFilePath)\n",
    "    print('标注数据:', gtLabel)\n",
    "    # detect() prints its timing before the result line below is printed\n",
    "    print('预测数据:', detect(image_3dArray))\n",
    "\n",
    "\n",
    "test_1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.2 查看模型预测画框后的图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "预测耗时0.0070秒\n"
     ]
    }
   ],
   "source": [
    "def test_2():\n",
    "    \"\"\"Show one random test image with the predicted boxes drawn on it.\n",
    "\n",
    "    Class 1 boxes use the colour [0, 255, 0], all others [255, 0, 0].\n",
    "    \"\"\"\n",
    "    imageFilePath = imageFilePath_list[random.choice(testIndex_1dArray)]\n",
    "    image_3dArray = np.array(Image.open(imageFilePath))\n",
    "    # detect returns (boxes, class ids); zip pairs them up again\n",
    "    for (x1, y1, x2, y2), classId in zip(*detect(image_3dArray)):\n",
    "        color = [0, 255, 0] if classId == 1 else [255, 0, 0]\n",
    "        cv2.rectangle(image_3dArray, (x1, y1), (x2, y2), color, 3)\n",
    "    Image.fromarray(image_3dArray).show()\n",
    "\n",
    "\n",
    "test_2()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5.3 模型评价: 整个测试集上的精确率、召回率"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5.3.1 通过例子理解精确率、召回率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "\n",
    "def eval_model(true_y, predicted_y, category_list):\n",
    "    \"\"\"Tabulate precision, recall, F1 and support per class and overall.\n",
    "\n",
    "    Args:\n",
    "        true_y: Ground-truth class ids.\n",
    "        predicted_y: Predicted class ids, aligned with ``true_y``.\n",
    "        category_list: Class names indexed by class id.\n",
    "\n",
    "    Returns:\n",
    "        A DataFrame with the columns Label, Precision, Recall, F1 and\n",
    "        Support: one row per class id that occurs in either input, plus a\n",
    "        final row (index 999) with the support-weighted averages.\n",
    "    \"\"\"\n",
    "    # Request the metrics for exactly the class ids that occur, in sorted\n",
    "    # order, so that each metric row is paired with its own class name no\n",
    "    # matter which classes happen to be absent.\n",
    "    present_classId_list = sorted(set(true_y) | set(predicted_y))\n",
    "    p, r, f1, s = precision_recall_fscore_support(\n",
    "        true_y, predicted_y, labels=present_classId_list)\n",
    "    category_1dArray = np.array([category_list[k] for k in present_classId_list])\n",
    "    df = pd.DataFrame([category_1dArray, p, r, f1, s]).T\n",
    "    df.columns = ['Label', 'Precision', 'Recall', 'F1', 'Support']\n",
    "    # Overall row: metrics averaged with the class supports as weights\n",
    "    all_label = '总体'\n",
    "    all_p = np.average(p, weights=s)\n",
    "    all_r = np.average(r, weights=s)\n",
    "    all_f1 = np.average(f1, weights=s)\n",
    "    all_s = np.sum(s)\n",
    "    df.loc[999] = [all_label, all_p, all_r, all_f1, all_s]\n",
    "    # Show Precision, Recall and F1 with 4 decimals. Series.map is used\n",
    "    # per column because DataFrame.applymap is deprecated in recent pandas.\n",
    "    for column in ['Precision', 'Recall', 'F1']:\n",
    "        df[column] = df[column].map(lambda x: '%.4f' % x)\n",
    "    return df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1</th>\n",
       "      <th>Support</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>background</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>0.0000</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>keyPoint_1</td>\n",
       "      <td>0.5000</td>\n",
       "      <td>0.3333</td>\n",
       "      <td>0.4000</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>keyPoint_2</td>\n",
       "      <td>0.5000</td>\n",
       "      <td>0.5000</td>\n",
       "      <td>0.5000</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>总体</td>\n",
       "      <td>0.3571</td>\n",
       "      <td>0.2857</td>\n",
       "      <td>0.3143</td>\n",
       "      <td>7</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          Label Precision  Recall      F1 Support\n",
       "0    background    0.0000  0.0000  0.0000       2\n",
       "1    keyPoint_1    0.5000  0.3333  0.4000       3\n",
       "2    keyPoint_2    0.5000  0.5000  0.5000       2\n",
       "999          总体    0.3571  0.2857  0.3143       7"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Toy example: 7 samples with hand-written true and predicted class ids\n",
    "gt_y = [1, 1, 1, 0, 2, 2, 0] \n",
    "p_y = [1, 0, 0, 1, 2, 0, 2]\n",
    "# Note: this rebinds the notebook-level category_list\n",
    "category_list = ['background', 'keyPoint_1', 'keyPoint_2']\n",
    "eval_model(gt_y, p_y, category_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 5.3.2 获取测试集所有样本的真实值、预测值"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground truth: the preloaded labels of the test samples\n",
    "gtLabel_list = [all_label_list[index] for index in testIndex_1dArray]\n",
    "# Predictions: detect() on each preloaded test image, in the same order\n",
    "pLabel_list = [\n",
    "    detect(all_images_4dArray[index], verbose=False)\n",
    "    for index in testIndex_1dArray\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Label</th>\n",
       "      <th>Precision</th>\n",
       "      <th>Recall</th>\n",
       "      <th>F1</th>\n",
       "      <th>Support</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>keyPoint_1</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>161</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>keyPoint_2</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>173</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>999</th>\n",
       "      <td>总体</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>1.0000</td>\n",
       "      <td>334</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "          Label Precision  Recall      F1 Support\n",
       "0    keyPoint_1    1.0000  1.0000  1.0000     161\n",
       "1    keyPoint_2    1.0000  1.0000  1.0000     173\n",
       "999          总体    1.0000  1.0000  1.0000     334"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def test_3(gtLabel_list, pLabel_list):\n",
    "    \"\"\"Turn per-image labels and predictions into class pairs and score them.\n",
    "\n",
    "    A ground-truth box counts as found when a box of the same class was\n",
    "    predicted and the absolute coordinate differences sum to less than 20;\n",
    "    otherwise it is paired with class 0. Every prediction that was not\n",
    "    matched this way is paired with ground-truth class 0.\n",
    "\n",
    "    Args:\n",
    "        gtLabel_list: One ``(box_list, classId_list)`` tuple per image.\n",
    "        pLabel_list: The predictions in the same format and order.\n",
    "\n",
    "    Returns:\n",
    "        The DataFrame produced by ``eval_model``.\n",
    "    \"\"\"\n",
    "    gt_y, p_y = [], []\n",
    "    for gtLabel, pLabel in zip(gtLabel_list, pLabel_list):\n",
    "        gtBox_list, gtClassId_list = gtLabel\n",
    "        pBox_list, pClassId_list = pLabel\n",
    "        matched_index_set = set()\n",
    "        # First pass: the ground-truth objects\n",
    "        for gtBox, gtClassId in zip(gtBox_list, gtClassId_list):\n",
    "            predicted_classId = 0\n",
    "            if gtClassId in pClassId_list:\n",
    "                index = pClassId_list.index(gtClassId)\n",
    "                diffSum = np.abs(np.subtract(gtBox, pBox_list[index])).sum()\n",
    "                if diffSum < 20:\n",
    "                    predicted_classId = gtClassId\n",
    "                    matched_index_set.add(index)\n",
    "            gt_y.append(gtClassId)\n",
    "            p_y.append(predicted_classId)\n",
    "        # Second pass: unmatched predictions, i.e. background reported as\n",
    "        # a key point\n",
    "        for index in range(len(pBox_list)):\n",
    "            if index not in matched_index_set:\n",
    "                gt_y.append(0)\n",
    "                p_y.append(pClassId_list[index])\n",
    "    category_list = ['background', 'keyPoint_1', 'keyPoint_2']\n",
    "    return eval_model(gt_y, p_y, category_list)\n",
    "\n",
    "\n",
    "test_3(gtLabel_list, pLabel_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6.模型使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.1 模型保存"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the weights directory when needed, then store the state_dict\n",
    "dirPath = '../resources/trained_weights'\n",
    "os.makedirs(dirPath, exist_ok=True)\n",
    "pthFileName = 'ckpt.pth'\n",
    "pthFilePath = os.path.join(dirPath, pthFileName)\n",
    "torch.save(net.state_dict(), pthFilePath)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2  模型加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "pthFilePath = '../resources/trained_weights/ckpt.pth'\n",
    "# Load the tensors onto the CPU first, so the checkpoint also opens on a\n",
    "# machine without the CUDA device it was saved from; load_state_dict\n",
    "# copies the values into the parameters wherever those live.\n",
    "state_dict = torch.load(pthFilePath, map_location='cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (backbone): Backbone(\n",
       "    (conv1_1): BasicConv(\n",
       "      (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv): Conv2d(3, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    )\n",
       "    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv2_1): BasicConv(\n",
       "      (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU()\n",
       "      (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    )\n",
       "    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
       "    (conv3_1): BasicConv(\n",
       "      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU()\n",
       "      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))\n",
       "    )\n",
       "  )\n",
       "  (prediction_conv): Conv2d(32, 5, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
       "  (softmax): Softmax(dim=1)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = torch.device('cuda')\n",
    "net.load_state_dict(state_dict)\n",
    "# to() and eval() both return the module, so the cell still displays it\n",
    "net.to(device).eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.3显存占用分析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "       BatchNorm2d-1         [-1, 3, 1920, 320]               6\n",
      "            Conv2d-2          [-1, 8, 960, 160]             224\n",
      "         BasicConv-3          [-1, 8, 960, 160]               0\n",
      "         MaxPool2d-4           [-1, 8, 480, 80]               0\n",
      "       BatchNorm2d-5           [-1, 8, 480, 80]              16\n",
      "              ReLU-6           [-1, 8, 480, 80]               0\n",
      "            Conv2d-7          [-1, 16, 240, 40]           1,168\n",
      "         BasicConv-8          [-1, 16, 240, 40]               0\n",
      "         MaxPool2d-9          [-1, 16, 120, 20]               0\n",
      "      BatchNorm2d-10          [-1, 16, 120, 20]              32\n",
      "             ReLU-11          [-1, 16, 120, 20]               0\n",
      "           Conv2d-12           [-1, 32, 60, 10]           4,640\n",
      "        BasicConv-13           [-1, 32, 60, 10]               0\n",
      "         Backbone-14           [-1, 32, 60, 10]               0\n",
      "           Conv2d-15            [-1, 5, 60, 10]           1,445\n",
      "================================================================\n",
      "Total params: 7,531\n",
      "Trainable params: 7,531\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 7.03\n",
      "Forward/backward pass size (MB): 43.53\n",
      "Params size (MB): 0.03\n",
      "Estimated Total Size (MB): 50.59\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from torchsummary import summary\n",
    "# Per-layer output shapes and parameter counts for a 3 x 1920 x 320 input\n",
    "summary(net, (3,1920,320))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.03125"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MiB taken by one 1920 x 320 x 3 input at 4 bytes per value\n",
    "1 * 1920 * 320 * 3 * 4 / (2**20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element count of the input plus each distinct output shape of the summary\n",
    "a = 1920 * 320 * 3  + 8 * 960 * 160  + 8 * 480 * 80  + 16 * 240 * 40  + 16 * 120 * 20 + 32 * 60 * 10  + 5 * 60 * 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 206,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13.707733154296875"
      ]
     },
     "execution_count": 206,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The same count converted to MiB at 4 bytes per element\n",
    "(a * 4) / (2**20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
